{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FG7kC-qt4C99"
      },
      "source": [
        "# Load Head with id2label"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BDbnBJQ64ByT",
        "outputId": "f8cb8677-8e41-4788-acde-d30f12048cc6",
        "ExecuteTime": {
          "end_time": "2023-08-17T10:42:49.471980Z",
          "start_time": "2023-08-17T10:42:46.167021600Z"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m204.3/204.3 kB\u001b[0m \u001b[31m3.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m20.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m268.8/268.8 kB\u001b[0m \u001b[31m25.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.8/7.8 MB\u001b[0m \u001b[31m55.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[2K     \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.3/1.3 MB\u001b[0m \u001b[31m48.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
            "\u001b[?25h"
          ]
        }
      ],
      "source": [
        "# Use the %pip magic (not !pip) so the package is installed into the running kernel's environment.\n",
        "%pip install -Uq adapters"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zd8fsYXX4uxm"
      },
      "source": [
        "First, we create the model and load the trained adapter together with its head. Loading the head also loads the labels and the id2label mapping automatically. You can access them with `model.get_labels()` and `model.get_labels_dict()`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 276,
          "referenced_widgets": [
            "3b73992217c048309d177cc6b3241e58",
            "c8a9b7d8337c445abc8f8585a029d549",
            "6f2e64b7b3e74c668d5ee925838d5b84",
            "bfede5717808409593e07143d12288d5",
            "6da7cc4c94e64f80909c8f9f5d8cfa49",
            "e865d30384f5419f892fef7652626919",
            "a9e0e563a515453586c592e1807e6868",
            "10f16a015858446295d936d21c67185e",
            "05f262dfc9bb451a93c3119fb0e53914",
            "0f379ddd6a4e41799a1adf6a4db1be96",
            "7fa8f3852b4c4fc3979c7d321366b84c",
            "8f8c43791bcb4dd1b177c2e28f5b0c1a",
            "799dce6076794898a2cd1ab05940db4d",
            "261ec051dfcc459a9d7f6d078f974a59",
            "fd681fb7f8b14c4ca0ebd5caf0c63784",
            "0e015268bc01441a84f21042b702cd51",
            "8acafdb959364f9090f4a145ced7aaea",
            "72892582b4ff4e7fb13b47deb2b8f9bd",
            "56182af714764b6f835075ae7b6ee0ad",
            "90c6f163df4044a096a60dbcab450d74",
            "7c7874c20b4b4a10974b3943f3ab23c8",
            "08b9956abd5b4e1da4f51f9bf618b265",
            "bde809ceadb844de932030dd4abee016",
            "a91140923b974c38ae3d7a4ecb167974",
            "e982c3ded3b546279de24c389f65a8f8",
            "1d9f694be75f4315975543df3d1b0d00",
            "e283910350a144649b939905ff85039c",
            "fb46966d4eb046cdb3841e2c69dd7b6a",
            "66a99206c3d54a3c8d6d828ec54b868a",
            "ffe52290c6b0423cb694826db94a960a",
            "b3d3c95c6acd470292c2e4bf2613d174",
            "7ff39d6308714ae3a04e63ba0b0b1369",
            "0d70f06573ea4c199115289964488bdf",
            "e3ba4e185ddc43baa01414de3ed3c3d3",
            "ed14f066dedc49719c06204327a42b70",
            "f159438511dc41d19ba0cdb2e8ec43ee",
            "5ce95504bcd349e884587751465bb567",
            "ffd38661258546f98e56eda18e294250",
            "653d89e3f81548dd8a75ed229c807768",
            "499887cfa6424365bd14c6ca1ca62c46",
            "7c111fbd10b140458d2c64b7a36f3f55",
            "eb5e0e09738c45eaaca0e937c5735910",
            "5435125ec35348eaaded999ade569be3",
            "b8612b32497f49e681bd2b83cd6c9bea",
            "d9a6cb83232b41ca84c88ed0688ede10",
            "28aba247cf7f4fdf8f1b89eea6c80b9a",
            "83532eaf4e424797902759a5fe172eb0",
            "84379d97278d456ebd1d5a813671efc3",
            "9d4640d933c04beda0e11076b8e0ada7",
            "93b93d04d7db41af86b1a64db53b3bc2",
            "8c9198fcedab445387161aa0a0369488",
            "ceb30b618e3443cdafd327db5612c214",
            "d2a294e736ff451b87ecff59217ca4a8",
            "2d2d0f1baed64ca384eaa606586c4699",
            "b5d389e9ad6e4352822e9823da187064",
            "f3cc9087e77345a08aeeee3eb3fd83ef",
            "a820779d9fa3441dba8f2826f902c702",
            "438f0cc9e65b4f39b095e1ea9dc1d6e9",
            "18386374f6d441bc88486af9b099d797",
            "42348345de1d49f7b097bb117fb5b9a9",
            "69ae6fd7513841cf918bea43dfd5b7b1",
            "cb4d186f354d4d6d8041023152d5528c",
            "3a9172ac6f78466e9ce67455d296375c",
            "57cbdb7f5c0b404bb339e6ab437c58f6",
            "05af67913d1a482581021212ea2c2ef4",
            "a427db91c95d4261ab8f369f19d40321",
            "d9edd2bf972e4d64b01c73d48dd36063",
            "3531a1c32d7342e7b76d959052aa6740",
            "76be829764d34d45b661670c33f9bc12",
            "a6b3ccfcaf3c4755ba677ddda4c1064b",
            "51de59c3e0f649aba3c4f41a765f7378",
            "75006ce26ced473b8e63b03ab987153a",
            "54ced2582f8b4f7ca18e729c0be56a66",
            "d1f5b64d9d824b2188368026aab6e12a",
            "83eb1f704d6e4ac0a53033668b7da350",
            "cef47823c90447ec88bd6ac53d2e1496",
            "b53aa9b03bee4b7ba5d1e661094cc984"
          ]
        },
        "id": "3uGJQZPz4QNz",
        "outputId": "cfa9ab4c-b0ba-42e8-c62b-435765c8e8a0",
        "ExecuteTime": {
          "end_time": "2023-08-17T10:42:56.807720900Z",
          "start_time": "2023-08-17T10:42:49.474973600Z"
        }
      },
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)lve/main/config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "3b73992217c048309d177cc6b3241e58"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "8f8c43791bcb4dd1b177c2e28f5b0c1a"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)rt-base-uncased.json:   0%|          | 0.00/690 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "bde809ceadb844de932030dd4abee016"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)nll2003_pfeiffer.zip:   0%|          | 0.00/3.36M [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "e3ba4e185ddc43baa01414de3ed3c3d3"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)okenizer_config.json:   0%|          | 0.00/28.0 [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d9a6cb83232b41ca84c88ed0688ede10"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "f3cc9087e77345a08aeeee3eb3fd83ef"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "display_data",
          "data": {
            "text/plain": [
              "Downloading (…)/main/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
            ],
            "application/vnd.jupyter.widget-view+json": {
              "version_major": 2,
              "version_minor": 0,
              "model_id": "d9edd2bf972e4d64b01c73d48dd36063"
            }
          },
          "metadata": {}
        },
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "['B-LOC', 'B-MISC', 'B-ORG', 'B-PER', 'I-LOC', 'I-MISC', 'I-ORG', 'I-PER', 'O']\n",
            "{1: 'B-LOC', 7: 'B-MISC', 5: 'B-ORG', 3: 'B-PER', 2: 'I-LOC', 8: 'I-MISC', 6: 'I-ORG', 4: 'I-PER', 0: 'O'}\n"
          ]
        }
      ],
      "source": [
        "import numpy as np\n",
        "import torch\n",
        "from adapters import AutoAdapterModel\n",
        "from transformers import AutoTokenizer\n",
        "\n",
        "# Checkpoint name shared by the base model and its tokenizer.\n",
        "model_name = \"bert-base-uncased\"\n",
        "\n",
        "# Create the base model, then load the trained adapter (and its head);\n",
        "# the value returned by load_adapter() is used below as the adapter name.\n",
        "model = AutoAdapterModel.from_pretrained(model_name)\n",
        "adapter_name = model.load_adapter(\"ner/conll2003@ukp\")\n",
        "\n",
        "# Select the adapter and the prediction head for subsequent forward passes.\n",
        "model.active_adapters = adapter_name\n",
        "model.active_head = \"ner_head\"\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "\n",
        "# Access the labels and the id -> label mapping of the pretrained head.\n",
        "print(model.get_labels(), model.get_labels_dict(), sep=\"\\n\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1-Ps9GKJ5ahe"
      },
      "source": [
        "This helper function returns the tokens of an input sentence together with the sequence of predicted label ids for those tokens."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "ylUZiWHs6Oay",
        "ExecuteTime": {
          "end_time": "2023-08-17T10:42:56.917699300Z",
          "start_time": "2023-08-17T10:42:56.812499800Z"
        }
      },
      "outputs": [],
      "source": [
        "def predict(sentence):\n",
        "    \"\"\"Predict a label id for every token of ``sentence``.\n",
        "\n",
        "    Args:\n",
        "        sentence: Raw input text.\n",
        "\n",
        "    Returns:\n",
        "        A tuple ``(tokens, label_ids)``: the tokens produced by\n",
        "        ``tokenizer.tokenize(sentence)`` and a NumPy array with the argmax\n",
        "        label id for each encoded position except the first and the last.\n",
        "    \"\"\"\n",
        "    input_ids = tokenizer.encode(sentence, return_tensors=\"pt\")\n",
        "    model.eval()\n",
        "    # Inference only: disable autograd so no graph is built and then thrown away.\n",
        "    with torch.no_grad():\n",
        "        # NOTE(review): the adapter name is hard-coded here; confirm it matches\n",
        "        # the adapter activated when the model was loaded.\n",
        "        logits = model(input_ids, adapter_names=['ner'])[0]\n",
        "    # Pick the single batch entry explicitly and drop the first and last\n",
        "    # positions (presumably the special tokens added by encode()) so the ids\n",
        "    # line up with tokenizer.tokenize(sentence).\n",
        "    label_ids = np.argmax(logits.numpy(), axis=-1)[0, 1:-1]\n",
        "    return tokenizer.tokenize(sentence), label_ids"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "WuW4eYx5WJE9"
      },
      "source": [
        "To predict the labels of a sentence, we can use the `model.get_labels_dict()` function to map each predicted label id to its corresponding label, as shown below for the example text."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-q1ErWcF5Z2w",
        "outputId": "16098d31-3aaf-42db-f1e7-4a4ae78f31ad",
        "ExecuteTime": {
          "end_time": "2023-08-17T10:42:57.077063800Z",
          "start_time": "2023-08-17T10:42:56.846516300Z"
        }
      },
      "outputs": [
        {
          "output_type": "stream",
          "name": "stdout",
          "text": [
            "germany(B-LOC) '(O) s(O) representative(O) to(O) the(O) european(B-ORG) union(I-ORG) '(O) s(O) veterinary(B-ORG) committee(I-ORG) werner(B-PER) z(I-PER) ##wing(I-PER) ##mann(I-PER) said(O) on(O) wednesday(O) consumers(O) should(O) buy(O) sheep(O) ##me(O) ##at(O) from(O) countries(O) other(O) than(O) britain(B-LOC) until(O) the(O) scientific(O) advice(O) was(O) clearer(O) .(O) "
          ]
        }
      ],
      "source": [
        "example_text = \"Germany's representative to the European Union\\'s veterinary committee Werner Zwingmann said on Wednesday consumers should buy sheepmeat from countries other than Britain until the scientific advice was clearer.\"\n",
        "\n",
        "# Mapping from predicted label id to label name\n",
        "label_map = model.get_labels_dict()\n",
        "tokens, preds = predict(example_text)\n",
        "\n",
        "# Show every token followed by its predicted label in parentheses, all on one line\n",
        "annotated = [f\"{token}({label_map[pred]}) \" for token, pred in zip(tokens, preds)]\n",
        "print(\"\".join(annotated), end=\"\")"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.10.6"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
